/* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* This file is based on object detection sample ssd example in
* https://github.com/openvinotoolkit/openvino/
*/

#include <iostream>
#include <string>
#include <memory>
#include <vector>
#include <algorithm>
#include <map>

#include <format_reader_ptr.h>
#include <inference_engine.hpp>
#include <ext_list.hpp>

#include <samples/common.hpp>
#include <samples/slog.hpp>

#include "ssd_od.h"

// Thickness of a line (in pixels) to be used for bounding boxes.
#define BBOX_THICKNESS 2

/**
*  Inference Engine load model.
*
*  @param modelFile Model file full path.
*  @return 0 on success. other values indicate failure.
*/
int SSDObjectDetection::load_model(const std::string &modelFile) {
    try {        
        slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;

        // Load inference engine
        slog::info << "Loading Inference Engine" << slog::endl;
        Core ie;
        std::string device = "CPU";

        slog::info << "Device info: " << slog::endl;
        std::cout << ie.GetVersions("CPU");        

        /**
        * cpu_extensions library is compiled from "extension" folder containing
        * custom MKLDNNPlugin layer implementations. These layers are not supported
        * by mkldnn, but they can be useful for inferring custom topologies.
        **/
        ie.AddExtension(std::make_shared<Extensions::Cpu::CpuExtensions>(),
                        "CPU");
        // -------------------------------------------------------------------

        // Read IR Generated by ModelOptimizer (.xml and .bin files)
        std::string binFileName = fileNameNoExt(modelFile) + ".bin";
        slog::info << "Loading network files:"
            "\n\t" << modelFile <<
            "\n\t" << binFileName <<
            slog::endl;

        CNNNetReader networkReader;
        // Read network model.
        networkReader.ReadNetwork(modelFile);

        // Extract model name and load weights.
        networkReader.ReadWeights(binFileName);
        network = networkReader.getNetwork();
        // -------------------------------------------------------------------

        // Prepare input blobs
        slog::info << "Preparing input blobs" << slog::endl;

        // Taking information about all topology inputs.
        inputsInfo = network.getInputsInfo();

        // SSD network has one input and one output.
        if (inputsInfo.size() != 1 && inputsInfo.size() != 2) {
            throw std::logic_error("Supports topologies only with 1 or 2 inputs");
        }

        /**
         * Some networks have SSD-like output format (ending with DetectionOutput layer), but
         * having 2 inputs as Faster-RCNN: one for image and one for "image info".
         *
         * Although ssd object detection main task is to support clean SSD, it could score
         * the networks with two inputs as well. 
         * For such networks imInfoInputName will contain the "second" input name.
         */
        inputInfo = nullptr;

        SizeVector inputImageDims;
        // Iterating over all input blobs.
        for (auto & item : inputsInfo) {
            // Working with first input tensor that stores image.
            // TODO: Remove hardcoding of numbers and assign meaningful variable names.
            if (item.second->getInputData()->getTensorDesc().getDims().size() == 4) {
                imageInputName = item.first;

                inputInfo = item.second;

                slog::info << "Batch size is "
                    << std::to_string(networkReader.getNetwork().getBatchSize())
                    << slog::endl;

                // Creating first input blob.
                Precision inputPrecision = Precision::U8;
                item.second->setPrecision(inputPrecision);
            } else if (item.second->getInputData()->getTensorDesc().getDims().size() == 2) {
                imInfoInputName = item.first;

                Precision inputPrecision = Precision::FP32;
                item.second->setPrecision(inputPrecision);
                if ((item.second->getTensorDesc().getDims()[1] != 3
                     && item.second->getTensorDesc().getDims()[1] != 6)) {
                    throw std::logic_error("Invalid input info. Should be 3 or 6 values length");
                }
            }
        }

        if (inputInfo == nullptr) {
            inputInfo = inputsInfo.begin()->second;
        }
        // -------------------------------------------------------------------

        // Prepare output blobs.
        slog::info << "Preparing output blobs" << slog::endl;

        OutputsDataMap outputsInfo(network.getOutputsInfo());
       
        DataPtr outputInfo;
        for (const auto& out : outputsInfo) {
            if (out.second->getCreatorLayer().lock()->type == "DetectionOutput") {
                outputName = out.first;
                outputInfo = out.second;
            }
        }

        if (outputInfo == nullptr) {
            throw std::logic_error("Can't find a DetectionOutput layer in the topology");
        }

        const SizeVector outputDims = outputInfo->getTensorDesc().getDims();

        maxProposalCount = outputDims[2];
        objectSize = outputDims[3];

        if (objectSize != 7) {
            throw std::logic_error("Output item should have 7 as a last dimension");
        }

        if (outputDims.size() != 4) {
            throw std::logic_error("Incorrect output dimensions for SSD model");
        }

        /** Set the precision of output data provided by the user.
         *  should be called before load of the network to the device
        **/
        outputInfo->setPrecision(Precision::FP32);
        // -------------------------------------------------------------------

        // Loading model to the device.
        slog::info << "Loading model to the device" << slog::endl;

        executable_network = ie.LoadNetwork(network, "CPU", {});
    }
    catch (const std::exception& error) {
        slog::err << error.what() << slog::endl;
        return 1;
    }
    catch (...) {
        slog::err << "Unknown/internal exception happened." << slog::endl;
        return 1;
    }

    slog::info << "Load Model successful" << slog::endl;
    return 0;
}


/**
*  Inference Engine object detection function.
*
*  @param imageFileIn Input image full path.
*  @param imageFileOut Output image full path.
*  @return 0 on success. other values indicate failure.
*/
int SSDObjectDetection::object_detection(const std::string &imageFileIn,
                                         const std::string &imageFileOut) {
    try {
        // Drawing only objects whose confidence exceeds this threshold.
        constexpr float kConfidenceThreshold = 0.5f;

        // Read input.
        // This vector stores paths to the processed images (exactly one here,
        // but the batch-oriented loops below are kept generic).
        std::vector<std::string> images;
        images.push_back(imageFileIn);
        // -------------------------------------------------------------------

        // Create infer request
        slog::info << "Create infer request" << slog::endl;
        InferRequest infer_request = executable_network.CreateInferRequest();
        // -------------------------------------------------------------------

        // Prepare input.
        // Collect images data ptrs: both the original-resolution pixels (for
        // drawing the output) and the network-resized pixels (for inference).
        std::vector<std::shared_ptr<unsigned char>> imagesData, originalImagesData;
        std::vector<size_t> imageWidths, imageHeights;
        for (auto & i : images) {
            FormatReader::ReaderPtr reader(i.c_str());
            if (reader.get() == nullptr) {
                slog::warn << "Image " + i + " cannot be read!" << slog::endl;
                continue;
            }
            // Store image data.
            std::shared_ptr<unsigned char> originalData(reader->getData());
            // dims is NCHW: dims[2] = height, dims[3] = width.
            auto dims = inputInfo->getTensorDesc().getDims();
            std::shared_ptr<unsigned char> data(reader->getData(dims[3], dims[2]));

            if (data.get() != nullptr) {
                originalImagesData.push_back(originalData);
                imagesData.push_back(data);
                imageWidths.push_back(reader->width());
                imageHeights.push_back(reader->height());
            }
        }
        if (imagesData.empty()) {
            throw std::logic_error("Valid input images were not found!");
        }

        size_t batchSize = network.getBatchSize();
        slog::info << "Batch size is " << std::to_string(batchSize) << slog::endl;
        if (batchSize != imagesData.size()) {
            slog::warn << "Number of images " + std::to_string(imagesData.size()) + \
                " doesn't match batch size " + std::to_string(batchSize) << slog::endl;
            // Process only as many images as both the batch and the input allow.
            batchSize = std::min(batchSize, imagesData.size());
            slog::warn << "Number of images to be processed is "
                << std::to_string(batchSize) << slog::endl;
        }

        // Creating input blob.
        Blob::Ptr imageInput = infer_request.GetBlob(imageInputName);

        // Filling input tensor with images. First b channel, then g and r channels.
        auto image_dims = imageInput->getTensorDesc().getDims();
        size_t num_channels = image_dims[1];
        size_t image_size = image_dims[3] * image_dims[2];

        unsigned char* data = static_cast<unsigned char*>(imageInput->buffer());

        // Iterate over all input images, converting interleaved (HWC) pixels
        // from the reader into the planar (CHW) layout the blob expects.
        for (size_t image_id = 0; image_id < std::min(imagesData.size(), batchSize); ++image_id) {
            // Iterate over all pixel in image (b,g,r).
            for (size_t pid = 0; pid < image_size; pid++) {
                // Iterate over all channels.
                for (size_t ch = 0; ch < num_channels; ++ch) {
                    /** [images stride + channels stride + pixel id ] all in bytes **/
                    data[image_id * image_size * num_channels + ch * image_size + pid] = \
                            imagesData.at(image_id).get()[pid * num_channels + ch];
                }
            }
        }

        // Faster-RCNN-style topologies need a second "image info" input:
        // [input_height, input_width, scale factors...].
        if (imInfoInputName != "") {
            Blob::Ptr input2 = infer_request.GetBlob(imInfoInputName);
            auto imInfoDim = inputsInfo.find(imInfoInputName)->second->getTensorDesc().getDims()[1];

            // Fill input tensor with values.
            float *p = input2->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();

            for (size_t image_id = 0; image_id < std::min(imagesData.size(), batchSize); ++image_id) {
                auto dims = inputsInfo[imageInputName]->getTensorDesc().getDims();
                p[image_id * imInfoDim + 0] = static_cast<float>(dims[2]);
                p[image_id * imInfoDim + 1] = static_cast<float>(dims[3]);
                for (size_t k = 2; k < imInfoDim; k++) {
                    p[image_id * imInfoDim + k] = 1.0f;  // all scale factors are set to 1.0
                }
            }
        }
        // -------------------------------------------------------------------

        // Do inference
        slog::info << "Start inference" << slog::endl;
        infer_request.Infer();
        // -------------------------------------------------------------------

        // Process output
        slog::info << "Processing output blobs" << slog::endl;

        const Blob::Ptr output_blob = infer_request.GetBlob(outputName);
        const float* detection = static_cast<PrecisionTrait<Precision::FP32>::value_type*>(output_blob->buffer());

        // Per-image accumulators; boxes stores [x, y, w, h] quadruples.
        std::vector<std::vector<int> > boxes(batchSize);
        std::vector<std::vector<int> > classes(batchSize);

        // Each detection row is [image_id, label, conf, xmin, ymin, xmax, ymax],
        // with coordinates normalized to [0, 1].
        for (int curProposal = 0; curProposal < maxProposalCount; curProposal++) {
            auto image_id = static_cast<int>(detection[curProposal * objectSize + 0]);
            if (image_id < 0) {
                // A negative image_id marks the end of valid detections.
                break;
            }
            if (image_id >= static_cast<int>(batchSize)) {
                // Guard against malformed output: indexing imageWidths/boxes/
                // classes with an out-of-range batch id would be undefined
                // behavior.
                continue;
            }

            float confidence = detection[curProposal * objectSize + 2];
            auto label = static_cast<int>(detection[curProposal * objectSize + 1]);
            auto xmin = static_cast<int>(detection[curProposal * objectSize + 3] * imageWidths[image_id]);
            auto ymin = static_cast<int>(detection[curProposal * objectSize + 4] * imageHeights[image_id]);
            auto xmax = static_cast<int>(detection[curProposal * objectSize + 5] * imageWidths[image_id]);
            auto ymax = static_cast<int>(detection[curProposal * objectSize + 6] * imageHeights[image_id]);

            if (confidence > kConfidenceThreshold) {
                classes[image_id].push_back(label);
                boxes[image_id].push_back(xmin);
                boxes[image_id].push_back(ymin);
                boxes[image_id].push_back(xmax - xmin);
                boxes[image_id].push_back(ymax - ymin);
                std::cout << "[" << curProposal << "," << label << "] element, prob = " << confidence <<
                "    (" << xmin << "," << ymin << ")-(" << xmax << "," << ymax << ")" << " batch id : " << image_id;
                std::cout << std::endl;
            }
        }

        // Draw the accepted boxes on the original-resolution images and write
        // the annotated result to imageFileOut.
        for (size_t batch_id = 0; batch_id < batchSize; ++batch_id) {
            addRectangles(originalImagesData[batch_id].get(),
                          imageHeights[batch_id], imageWidths[batch_id],
                          boxes[batch_id], classes[batch_id],
                          BBOX_THICKNESS);
            if (writeOutputBmp(imageFileOut, originalImagesData[batch_id].get(),
                               imageHeights[batch_id], imageWidths[batch_id])) {
                slog::info << "Image " + imageFileOut + " created!" << slog::endl;
            } else {
                throw std::logic_error(std::string("Can't create a file: ") + imageFileOut);
            }
        }
    }
    catch (const std::exception& error) {
        slog::err << error.what() << slog::endl;
        return 1;
    }
    catch (...) {
        slog::err << "Unknown/internal exception happened." << slog::endl;
        return 1;
    }

    slog::info << "Execution successful" << slog::endl;
    return 0;
}
